# "An Estimation Procedure for the Hawkes Process" by Matthias Kirchner
library(gtools)
library(data.table)
library(tidyr)
library(xts)
library(glmnet)

# Conditional-least-squares estimate of a discretised Hawkes kernel.
#
# Every column of `data_ts` is regressed (via lm) on the `p` preceding rows of
# all columns; one regression per column is run through parallel::mclapply.
#
# Args:
#   data_ts:  matrix of binned counts, one row per bin and one column per
#             series.
#   p:        number of lagged bins used as regressors; 1 <= p < nrow(data_ts).
#   state_ts: NULL (plain regression with intercept), or
#             a matrix / data.frame / data.table whose k-th column is the state
#             of series k (fitted as `0 + design_matrix + state`), or
#             a list whose k-th element has the state in column 1 and the
#             price in column 2 (fitted as `0 + design_matrix + state + price`).
#   mc.cores: number of worker processes handed to mclapply.
#
# Returns:
#   The coefficient matrix assembled by organize_kernel_output(), one column
#   per series.
get_theta_cls <- function(data_ts, p, state_ts = NULL, mc.cores = parallel::detectCores()) {
  n <- nrow(data_ts)
  if (is.null(n) || length(p) != 1L || is.na(p) || p < 1 || p >= n) {
    stop("p must be a single lag count with 1 <= p < nrow(data_ts)", call. = FALSE)
  }
  lagged_rows <- seq_len(p)

  # Decide once, in the parent process, how `state_ts` is laid out. Doing it
  # inside the worker would turn the stop() into a "try-error" list element.
  state_kind <- if (is.null(state_ts)) {
    "none"
  } else if (class(state_ts)[1] %in% c("matrix", "data.frame", "data.table")) {
    "table"
  } else if (class(state_ts)[1] == "list") {
    "list"
  } else {
    stop("state_ts is not in the right class")
  }

  # Lag block j holds rows j..(n - 1 - p + j); block 1 is therefore the oldest
  # lag. drop = FALSE keeps the matrix shape when there is a single series.
  lag_blocks <- lapply(lagged_rows, function(j) {
    data_ts[j:(n - 1 - p + j), , drop = FALSE]
  })
  design_matrix <- do.call(cbind, lag_blocks)
  y_response <- data_ts[-lagged_rows, , drop = FALSE]

  # The names `design_matrix`, `state` and `price` are part of the coefficient
  # names that organize_kernel_output() matches on; do not rename them.
  fit_series <- function(k) {
    if (state_kind == "none") {
      return(lm(y_response[, k] ~ design_matrix)$coefficients)
    }
    if (state_kind == "table") {
      # `[[k]]` selects column k for both data.frame and data.table, whereas
      # `[rows, k]` with a variable k is not a column selection on a data.table.
      state_col <- if (is.matrix(state_ts)) {
        state_ts[-lagged_rows, k]
      } else {
        state_ts[[k]][-lagged_rows]
      }
      state <- as.factor(state_col)
      return(lm(y_response[, k] ~ 0 + design_matrix + state)$coefficients)
    }
    state <- as.factor(state_ts[[k]][-lagged_rows, 1])
    price <- as.factor(state_ts[[k]][-lagged_rows, 2])
    lm(y_response[, k] ~ 0 + design_matrix + state + price)$coefficients
  }

  hawkes_kernel <- parallel::mclapply(seq_len(ncol(data_ts)), fit_series,
                                      mc.cores = mc.cores)

  # mclapply reports worker errors as "try-error" elements; surface them here.
  failed <- vapply(hawkes_kernel, function(fit) inherits(fit, "try-error"), logical(1))
  if (any(failed)) {
    stop("regression failed for series ", paste(which(failed), collapse = ", "),
         ": ", as.character(hawkes_kernel[[which(failed)[1]]]), call. = FALSE)
  }

  names(hawkes_kernel) <- colnames(y_response)
  organize_kernel_output(hawkes_kernel)
}

# Stack per-series coefficient vectors into one matrix.
#
# Args:
#   hawkes_kernel: named list with one named numeric coefficient vector per
#                  series. Coefficients whose name contains "design_matrix" are
#                  treated as lag (kernel) coefficients; all others (intercept,
#                  state / time dummies, ...) as baseline terms.
#
# Returns:
#   A matrix with one column per series. The baseline rows come first, in
#   gtools::mixedsort order of their names, with NA where a series lacks a
#   term. The lag rows follow in the order of the first series' coefficients.
organize_kernel_output <- function(hawkes_kernel) {
  if (length(hawkes_kernel) == 0) {
    stop("hawkes_kernel must contain at least one coefficient vector", call. = FALSE)
  }
  is_lag_coef <- function(coefs) grepl("design_matrix", names(coefs))

  # Lag coefficients: build the matrix explicitly so that one or zero lag
  # coefficients per series still give a matrix (sapply would simplify to a
  # vector or a list in those cases).
  lag_coefs <- lapply(hawkes_kernel, function(coefs) coefs[is_lag_coef(coefs)])
  n_lag <- unique(lengths(lag_coefs))
  if (length(n_lag) != 1L) {
    stop("all series must have the same number of design_matrix coefficients",
         call. = FALSE)
  }
  hawkes_kernel_design_matrix <- matrix(
    as.numeric(unlist(lag_coefs, use.names = FALSE)),
    nrow = n_lag, ncol = length(lag_coefs),
    dimnames = list(names(lag_coefs[[1]]), names(hawkes_kernel))
  )

  # Baseline terms: the union of names over all series, NA where absent.
  mu_state <- lapply(hawkes_kernel, function(coefs) coefs[!is_lag_coef(coefs)])
  mu_state_names <- mixedsort(unique(unlist(lapply(mu_state, names))))
  mu_state_mat <- matrix(NA, nrow = length(mu_state_names), ncol = length(hawkes_kernel),
                         dimnames = list(mu_state_names, names(hawkes_kernel)))
  for (k in seq_along(mu_state)) {
    mu_state_mat[names(mu_state[[k]]), k] <- mu_state[[k]]
  }

  rbind(mu_state_mat, hawkes_kernel_design_matrix)
}

# Fit one regression for a single response series stored in a wide table.
#
# Args:
#   bigdt:        data.table (a plain data.frame also works) whose first column
#                 is the response. Columns whose name contains "state", "price"
#                 or "time" are side information; every other column is a lagged
#                 regressor. Regressor names are shortened to the text after the
#                 last "." (e.g. "V3.x" becomes "x").
#   model_config: single string selecting the model, one of
#                 "state", "time_customizedtwo", "state_time_customizedtwo",
#                 "hawkes", "hawkes_LASSO", "hawkes_LASSO_0005", "state_hawkes",
#                 "hawkes_time_customizedtwo", "state_hawkes_time_customizedtwo",
#                 "state_hawkes_time_customizedtwo_LASSO",
#                 "state_hawkes_time_customizedtwo_LASSO_0005",
#                 "state_hawkes_time5min", "state_hawkes_time1min",
#                 "state_hawkes_time_customizedone",
#                 "state_hawkes_time_customizedthree".
#
# Returns:
#   list(coefficients, residuals); coefficients is a named numeric vector.
#   Models that combine `state` with a time factor append a zero coefficient
#   for the first (baseline) level of that time factor.
get_hawkes_kernel_bigdt <- function(bigdt, model_config) {
  if (!is.character(model_config) || length(model_config) != 1L || is.na(model_config)) {
    stop("model_config must be a single character string", call. = FALSE)
  }

  # Plain numeric response. A one-column matrix response makes lm() return an
  # "mlm" fit whose coefficients are a matrix without names().
  y_dt <- as.numeric(bigdt[[1]])

  # Lagged regressors: everything after the response that is not side info.
  lag_cols <- names(bigdt)[-1]
  lag_cols <- lag_cols[!grepl("state|price|time", lag_cols)]
  lag_data <- if (inherits(bigdt, "data.table")) {
    bigdt[, lag_cols, with = FALSE]
  } else {
    bigdt[, lag_cols, drop = FALSE]
  }
  design_matrix <- data.matrix(lag_data)
  colnames(design_matrix) <- vapply(
    strsplit(lag_cols, ".", fixed = TRUE),
    function(parts) parts[length(parts)],
    character(1)
  )

  # Side-information columns are looked up only by the configs that use them,
  # so a table without e.g. a time_5min column still supports other configs.
  factor_column <- function(key) {
    hits <- grep(key, names(bigdt), fixed = TRUE, value = TRUE)
    if (length(hits) != 1L) {
      stop("model_config '", model_config, "' needs exactly one column matching '",
           key, "' in bigdt, found ", length(hits), call. = FALSE)
    }
    as.factor(bigdt[[hits]])
  }
  load_state <- function() {
    state <- factor_column("state")
    # A constant state carries no information; use a numeric zero column.
    if (nlevels(state) == 1L) rep(0, length(state)) else state
  }
  baseline_name <- function(prefix, time_factor) {
    paste0(prefix, levels(time_factor)[1])
  }

  # OLS fit. `baseline`, when given, names a zero coefficient appended for the
  # reference level absorbed by the other dummies (skipped if lm already
  # produced a coefficient of that name).
  fit_ols <- function(formula, baseline = NULL) {
    reg <- lm(formula)
    coefs <- reg$coefficients
    if (!is.null(baseline) && !(baseline %in% names(coefs))) {
      coefs <- c(coefs, setNames(0, baseline))
    }
    list(coefs, reg$residuals)
  }

  # LASSO on the lag coefficients only, with glmnet's own intercept.
  fit_lasso_hawkes <- function(lambda) {
    x <- model.matrix(y_dt ~ design_matrix)
    fit <- glmnet(x, y_dt, alpha = 1, lambda = lambda, thresh = 1e-10)
    beta <- coef(fit)
    # Entry 2 is the intercept column of `x` itself, which duplicates
    # glmnet's intercept; drop it.
    coefs <- setNames(as.numeric(beta)[-2], rownames(beta)[-2])
    list(coefs, as.numeric(y_dt - predict(fit, x)))
  }

  # LASSO with unpenalised state / time dummies and penalised lag coefficients.
  fit_lasso_state_time <- function(lambda) {
    x <- model.matrix(y_dt ~ 0 + state + time_customizedtwo + design_matrix)
    # Penalise exactly the lag columns instead of assuming a fixed number of
    # leading dummy columns.
    penalty <- as.numeric(startsWith(colnames(x), "design_matrix"))
    fit <- glmnet(x, y_dt, alpha = 1, lambda = lambda, intercept = FALSE,
                  thresh = 1e-10, penalty.factor = penalty)
    beta <- coef(fit)
    # glmnet still reports an "(Intercept)" row (zero here); relabel it as the
    # baseline time level so the output lines up with the OLS variants.
    baseline <- baseline_name("time_customizedtwo", time_customizedtwo)
    coefs <- setNames(as.numeric(beta), c(baseline, rownames(beta)[-1]))
    if (baseline %in% rownames(beta)[-1]) {
      coefs <- coefs[-1]
    }
    list(coefs, as.numeric(y_dt - predict(fit, x)))
  }

  # The variable names used in the formulas (`design_matrix`, `state`,
  # `time_*`) become coefficient-name prefixes; keep them as they are.
  switch(model_config,
    state = {
      state <- load_state()
      fit_ols(y_dt ~ 0 + state)
    },
    time_customizedtwo = {
      time_customizedtwo <- factor_column("time_customizedtwo")
      fit_ols(y_dt ~ 0 + time_customizedtwo)
    },
    state_time_customizedtwo = {
      state <- load_state()
      time_customizedtwo <- factor_column("time_customizedtwo")
      fit_ols(y_dt ~ 0 + state + time_customizedtwo,
              baseline = baseline_name("time_customizedtwo", time_customizedtwo))
    },
    hawkes = fit_ols(y_dt ~ design_matrix),
    hawkes_LASSO = fit_lasso_hawkes(0.001),
    hawkes_LASSO_0005 = fit_lasso_hawkes(0.0005),
    state_hawkes = {
      state <- load_state()
      fit_ols(y_dt ~ 0 + design_matrix + state)
    },
    hawkes_time_customizedtwo = {
      time_customizedtwo <- factor_column("time_customizedtwo")
      fit_ols(y_dt ~ 0 + design_matrix + time_customizedtwo)
    },
    state_hawkes_time_customizedtwo = {
      state <- load_state()
      time_customizedtwo <- factor_column("time_customizedtwo")
      fit_ols(y_dt ~ 0 + design_matrix + state + time_customizedtwo,
              baseline = baseline_name("time_customizedtwo", time_customizedtwo))
    },
    state_hawkes_time_customizedtwo_LASSO = {
      state <- load_state()
      time_customizedtwo <- factor_column("time_customizedtwo")
      fit_lasso_state_time(0.001)
    },
    state_hawkes_time_customizedtwo_LASSO_0005 = {
      state <- load_state()
      time_customizedtwo <- factor_column("time_customizedtwo")
      fit_lasso_state_time(0.0005)
    },
    state_hawkes_time5min = {
      state <- load_state()
      time_5min <- factor_column("time_5min")
      fit_ols(y_dt ~ 0 + design_matrix + state + time_5min,
              baseline = baseline_name("time_5min", time_5min))
    },
    state_hawkes_time_customizedone = {
      state <- load_state()
      time_customizedone <- factor_column("time_customizedone")
      fit_ols(y_dt ~ 0 + design_matrix + state + time_customizedone,
              baseline = baseline_name("time_customizedone", time_customizedone))
    },
    state_hawkes_time_customizedthree = {
      state <- load_state()
      time_customizedthree <- factor_column("time_customizedthree")
      fit_ols(y_dt ~ 0 + design_matrix + state + time_customizedthree,
              baseline = baseline_name("time_customizedthree", time_customizedthree))
    },
    state_hawkes_time1min = {
      state <- load_state()
      time_1min <- factor_column("time_1min")
      fit_ols(y_dt ~ 0 + design_matrix + state + time_1min,
              baseline = baseline_name("time_1min", time_1min))
    },
    stop("unknown model_config: ", model_config, call. = FALSE)
  )
}


# Build one wide regression table per series.
#
# Args:
#   data_ts:  matrix of binned counts, one row per bin and one column per
#             series.
#   p:        number of lagged bins used as regressors; 1 <= p < nrow(data_ts).
#   state_ts: list with `state` (a list holding one state vector per series)
#             and `time_customizedtwo` (a vector shared by all series), each of
#             length nrow(data_ts).
#
# Returns:
#   A list of data.tables named after colnames(data_ts). Table k holds, in this
#   order: the response (series k), the p lag blocks of all series, the state
#   and the time_customizedtwo label. Column names are prefixed with
#   "V<index>." to make them unique.
#
# Only state and time_customizedtwo are attached; the model configurations of
# get_hawkes_kernel_bigdt that use other time_* factors need those columns to
# be added here as well.
get_bigdt_list <- function(data_ts, p, state_ts) {
  n <- nrow(data_ts)
  if (is.null(n) || length(p) != 1L || is.na(p) || p < 1 || p >= n) {
    stop("p must be a single lag count with 1 <= p < nrow(data_ts)", call. = FALSE)
  }
  lagged_rows <- seq_len(p)

  # Lag block j holds rows j..(n - 1 - p + j). drop = FALSE keeps the matrix
  # shape (and the column names) when there is a single series.
  lag_blocks <- lapply(lagged_rows, function(j) {
    data_ts[j:(n - 1 - p + j), , drop = FALSE]
  })
  design_matrix <- do.call(cbind, lag_blocks)
  y_response <- data_ts[-lagged_rows, , drop = FALSE]

  bigdt_list <- lapply(seq_len(ncol(data_ts)), function(k) {
    bigdt <- cbind(y_response[, k, drop = FALSE],
                   design_matrix,
                   state_ts$state[[k]][-lagged_rows],
                   state_ts$time_customizedtwo[-lagged_rows])
    colnames(bigdt)[(ncol(bigdt) - 1):ncol(bigdt)] <- c("state", "time_customizedtwo")
    colnames(bigdt) <- paste0("V", 1:ncol(bigdt), ".", colnames(bigdt))
    bigdt <- data.table(bigdt)
    # Progress output: dimensions of the table just built.
    print(dim(bigdt))
    bigdt
  })
  names(bigdt_list) <- colnames(data_ts)
  bigdt_list
}

# Fit every table of `bigdt_list` and assemble the coefficient matrix.
#
# Args:
#   bigdt_list:   named list of regression tables, one per series.
#   mc.cores:     number of worker processes handed to mclapply.
#   model_config: model selector forwarded to get_hawkes_kernel_bigdt.
#                 NOTE(review): the default "state_hawkes" was chosen because
#                 it is the same formula get_theta_cls uses for tabular states
#                 (`0 + design_matrix + state`); confirm it is the intended
#                 default.
#
# Returns:
#   The coefficient matrix built by organize_kernel_output(), one column per
#   series. Residuals are not returned by this variant.
get_theta_cls_bigdt <- function(bigdt_list, mc.cores = parallel::detectCores(),
                                model_config = "state_hawkes") {
  fits <- parallel::mclapply(bigdt_list, function(bigdt) {
    get_hawkes_kernel_bigdt(bigdt, model_config)
  }, mc.cores = mc.cores)

  # mclapply reports worker errors as "try-error" elements; surface them here.
  failed <- vapply(fits, function(fit) inherits(fit, "try-error"), logical(1))
  if (any(failed)) {
    stop("regression failed for table ", paste(which(failed), collapse = ", "),
         ": ", as.character(fits[[which(failed)[1]]]), call. = FALSE)
  }

  # Each fit is list(coefficients, residuals); only the coefficients are
  # organised into the kernel matrix.
  hawkes_kernel <- lapply(fits, function(fit) fit[[1]])
  names(hawkes_kernel) <- names(bigdt_list)
  organize_kernel_output(hawkes_kernel)
}

# Fit regression tables that were saved to disk with saveRDS, one file per
# series, so that only one table per worker is held in memory at a time.
#
# Args:
#   dt_file_paths: list (or character vector) of .rds paths; its names are used
#                  as the series names.
#   mc.cores:      number of worker processes handed to mclapply.
#   model_config:  model selector forwarded to get_hawkes_kernel_bigdt.
#                  NOTE(review): the default "state_hawkes" was chosen because
#                  it is the same formula get_theta_cls uses for tabular states
#                  (`0 + design_matrix + state`); confirm it is the intended
#                  default.
#
# Returns:
#   list(kernel, residuals): the coefficient matrix built by
#   organize_kernel_output() and a list with one residual vector per series.
get_theta_cls_bigdt_fromfile <- function(dt_file_paths, mc.cores = parallel::detectCores(),
                                         model_config = "state_hawkes") {
  fits <- parallel::mclapply(dt_file_paths, function(file_path) {
    get_hawkes_kernel_bigdt(readRDS(file_path), model_config)
  }, mc.cores = mc.cores)

  # mclapply reports worker errors as "try-error" elements; surface them here.
  failed <- vapply(fits, function(fit) inherits(fit, "try-error"), logical(1))
  if (any(failed)) {
    stop("regression failed for file ", paste(which(failed), collapse = ", "),
         ": ", as.character(fits[[which(failed)[1]]]), call. = FALSE)
  }

  # Split each list(coefficients, residuals) by position, so the input does not
  # have to be named.
  hawkes_kernel <- lapply(fits, function(fit) fit[[1]])
  residuals <- lapply(fits, function(fit) fit[[2]])
  names(hawkes_kernel) <- names(dt_file_paths)
  names(residuals) <- names(dt_file_paths)

  hawkes_kernel <- organize_kernel_output(hawkes_kernel)
  list(hawkes_kernel, residuals)
}

# Estimate a discretised Hawkes kernel from an event table.
#
# Args:
#   event:       data.table with a `time` column and one count column per
#                event type.
#   state_func:  function mapping a vector of bin start times (bin index *
#                delta) to the state description expected by get_theta_cls or,
#                when it returns a list, by get_bigdt_list.
#   delta:       bin width.
#   support_max: kernel support; floor(support_max / delta) lags are used.
#   mc.cores:    number of worker processes.
#   save.config: optional list. When `state_func` returns a list and both
#                `save_dir` and `file_name` are set, the per-series regression
#                tables are written to disk and fitted from there; with
#                `save_bigdt_only = TRUE` the function stops after writing and
#                returns NULL.
#
# Returns:
#   list(hawkes_kernel_est, delta, support_max, residuals). The kernel matrix
#   is divided by `delta`. `residuals` is NULL on the in-memory path, which
#   does not compute residuals.
estimate_hawkes <- function(event, state_func, delta, support_max,
                            mc.cores = parallel::detectCores(), save.config = list()) {
  event_names <- setdiff(colnames(event), "time")

  # Sum the events falling into each bin of width delta.
  event$bin <- floor(event$time / delta)
  event_bin_count <- event[, lapply(.SD, sum), by = bin, .SDcols = event_names]
  rm(event)

  # Add zero rows for bins without any event so the series is regular. The
  # step is skipped when no bin is empty: filling a zero-row data.frame with a
  # scalar is an error.
  empty_bins <- setdiff(min(event_bin_count$bin):max(event_bin_count$bin),
                        event_bin_count$bin)
  if (length(empty_bins) > 0) {
    event_dummy <- data.frame(bin = empty_bins)
    for (coln in event_names) {
      event_dummy[[coln]] <- 0
    }
    event_bin_count <- rbind(event_bin_count, event_dummy)
  }
  event_bin_count <- event_bin_count[order(bin), ]

  event_bin_count_matrix <- data.matrix(event_bin_count[, event_names, with = FALSE])
  state_ts <- state_func(event_bin_count$bin * delta)
  rm(event_bin_count)

  p <- floor(support_max / delta)
  residuals <- NULL
  fit_from_disk <- inherits(state_ts, "list") &&
    !is.null(save.config$save_dir) && !is.null(save.config$file_name)

  if (fit_from_disk) {
    bigdt_list <- get_bigdt_list(data_ts = event_bin_count_matrix,
                                 p = p,
                                 state_ts = state_ts)
    dir.create(save.config$save_dir, showWarnings = FALSE, recursive = TRUE)
    dt_file_paths <- list()
    for (each_dt_name in names(bigdt_list)) {
      dt_path <- paste0(save.config$save_dir, "/", save.config$file_name, ".",
                        each_dt_name, "support", toString(support_max),
                        "delta", toString(delta), ".rds")
      saveRDS(bigdt_list[[each_dt_name]], dt_path)
      dt_file_paths[[each_dt_name]] <- dt_path
    }
    # Free the in-memory tables before the workers reload them from disk.
    rm(bigdt_list, event_bin_count_matrix, state_ts)
    if (isTRUE(as.logical(save.config$save_bigdt_only))) {
      return(NULL)
    }
    fit <- get_theta_cls_bigdt_fromfile(dt_file_paths, mc.cores = mc.cores)
    hawkes_kernel_est <- fit[[1]]
    residuals <- fit[[2]]
  } else {
    hawkes_kernel_est <- get_theta_cls(data_ts = event_bin_count_matrix,
                                       p = p,
                                       state_ts = state_ts,
                                       mc.cores = mc.cores)
  }

  # Coefficients are per bin; dividing by the bin width rescales them per unit
  # of time.
  hawkes_kernel_est <- hawkes_kernel_est / delta

  list(
    hawkes_kernel_est = hawkes_kernel_est,
    delta = delta,
    support_max = support_max,
    residuals = residuals
  )
}

# Scatter-plot one estimated kernel against time.
#
# Args:
#   time_pints:    x coordinates (times) of the kernel values.
#   phi_value:     estimated kernel values, same length as `time_pints`.
#   phi_true_func: optional function of time; when given it is evaluated on
#                  101 equally spaced points in [0, max(time_pints)] and drawn
#                  in colour 4.
#
# A smoothing spline through the finite values is added in colour 2 when at
# least four of them are available.
plot_hawkes_kernel <- function(time_pints, phi_value, phi_true_func=NULL){
  # Always include zero on the y axis.
  y_limits <- range(c(0, phi_value), na.rm = TRUE)
  plot(time_pints, phi_value,
       xlab = "Time in seconds",
       ylab = "Hawkes kernel value",
       ylim = y_limits)

  usable <- is.finite(phi_value)
  if (sum(usable) >= 4) {
    lines(smooth.spline(time_pints[usable], phi_value[usable]), col = 2)
  }

  if (!is.null(phi_true_func)) {
    reference_grid <- seq(0, max(time_pints), length.out = 101)
    lines(reference_grid, sapply(reference_grid, phi_true_func), col = 4)
  }
}

# Plot every pairwise kernel of an estimated Hawkes model.
#
# Args:
#   hawkes:        list with `hawkes_kernel_est` (coefficient matrix whose lag
#                  rows are named "design_matrix<event>"), `support_max` and
#                  `delta`; optionally `phi_mat_func` to overlay a reference
#                  kernel.
#   matrix_layout: if TRUE, arrange the plots in an ncol x ncol grid; otherwise
#                  draw one plot per page.
#
# One panel is drawn per (stimulater, stimulatee) pair via plot_hawkes_kernel,
# with a title and, for the known order-book event codes, a descriptive
# subtitle. The graphics layout is changed and not restored.
plot_hawkes <- function(hawkes, matrix_layout=TRUE){
  kernel_est <- hawkes$hawkes_kernel_est
  event_types <- colnames(kernel_est)

  if (matrix_layout) {
    layout(matrix(1:(ncol(kernel_est)^2), ncol = ncol(kernel_est)))
  } else {
    layout(1)
  }

  # The kernel values are plotted against the lag times in reverse order.
  # NOTE(review): this assumes the lag rows run from the longest lag to the
  # shortest, as the lag blocks built elsewhere in this file do - confirm.
  time_scale <- (1:floor(hawkes$support_max / hawkes$delta)) * hawkes$delta
  lag_times <- rev(time_scale)

  # Event code -> description lookup, built once for all panels.
  # cols_all <- paste0(rep(sprintf("%+d", c(1:3, -(1:3))), each=3), "(", c("+", "-", "t"), ")")
  # cols_all <- c(cols_all, c("p+(+)", "p+(-)", "p+(t)", "p-(+)", "p-(-)", "p-(t)"))
  cols_all <- paste0(rep(sprintf("%+d", c(1:3, -(1:3))), each=3), "(", c("i", "c", "t"), ")")
  cols_all <- c(cols_all, c("p+(i)", "p+(c)", "p+(t)", "p-(i)", "p-(c)", "p-(t)"))
  cols_legend <- paste0(c("insertion", "deletion", "trade"), "@", rep(c("1st ask price", "2nd ask price", "3rd ask price",
                                                                        "1st bid price", "2nd bid price", "3rd bid price"), each=3))
  cols_legend <- c(cols_legend, c("mid price incease because bid insertion", 
                                  "mid price incease because ask deletion",
                                  "mid price incease because ask trade",
                                  "mid price decrease because ask insertion", 
                                  "mid price decrease because bid deletion",
                                  "mid price decrease because bid trade"))

  for (stimulater in event_types) {
    # Lag coefficients of `stimulater` in every regression. drop = FALSE keeps
    # the matrix shape when there is a single lag or a single event type.
    kernel_this <- kernel_est[rownames(kernel_est) == paste0("design_matrix", stimulater), , drop = FALSE]

    for (stimulatee in event_types) {
      if (!is.null(hawkes$phi_mat_func)) {
        # NOTE(review): the guard tests hawkes$phi_mat_func, but the call goes
        # to a free function phi_mat_func_idx that is not defined in this
        # function - confirm it exists wherever this is used.
        phi_func <- function(t) {
          ret <- phi_mat_func_idx(t, stimulater)
          ret[names(ret) == stimulatee]
        }
      } else {
        phi_func <- NULL
      }

      plot_hawkes_kernel(lag_times, kernel_this[, stimulatee], phi_func)
      title(paste0("event of ", stimulater, " stimulate ", stimulatee))
      mtext(paste0(cols_legend[cols_all == stimulater], " -> ", cols_legend[cols_all == stimulatee]))
    }
  }
}
